import itertools

import torch
from collections import OrderedDict
from copy import deepcopy
import logging
import omegaconf

# Module-level logger shared by the loader functions in this module.
log = logging.getLogger(__name__)


# def load_all_models(trained_clients_list, clients, models):
#     # TODO: the updateing of state dict keys work for lenet only and not resnet, this NEEDS a FIX!!
#
#     # Load models weights
#     log.info(type(trained_clients_list))
#     # if type(all_clients) == tuple or type(all_clients) == list or type(all_clients) == omegaconf.ListConfig:
#     all_models = []
#     for t in trained_clients_list:
#             log.info('Loading the clients')
#             model_path = clients[f"client_{t}"].model_path
#             model = load_model(models, model_path)
#             all_models.append(deepcopy(model))
#         # teacher_model = teachers_models
#         # Assert that they are different
#     for model1, model2 in itertools.combinations(all_models, r=2):
#             equal = True
#             for (k1, v1), (k2, v2) in zip(
#                     model1.state_dict().items(), model2.state_dict().items()
#             ):
#                 assert not torch.equal(v1, v2)
#                 if not torch.equal(v1, v2):
#                     equal = False
#             assert not equal
#
#     return all_models

def load_models_dir_exp(model, pathdate, from_round=99, extra=""):
    """Load the ten per-client checkpoints of one Dirichlet-alpha transfer run.

    The location is hard-coded for a specific experiment layout; client ``i``
    (for ``i`` in ``0..9``) is read from
    ``.../{pathdate}/models/client-{i}/{extra}.ckpt`` via ``load_model3``.

    Args:
        model: Template module. ``load_model3`` refills it for every client, so
            each result is deep-copied before the next one is loaded.
        pathdate: Run-specific directory component of the path.
        from_round: Only used in the progress message (printed as
            ``from_round + 1``); it does not influence the checkpoint path.
        extra: Checkpoint file stem; any falsy value (e.g. ``None``) becomes "".

    Returns:
        list: Ten independent model copies, ordered by client index.
    """
    extra = extra or ""
    cpu = torch.device('cpu')

    snapshots = []
    for client_idx in range(10):
        print(
            f"Dirichlet alpha experiment, loading the models that was trained with FedAvg for {from_round + 1} rounds, to continue training")
        ckpt_path = f"/data/fat/alballns/outputs/transfer_exp/opt-adam_from-C[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]-to-C[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]_dm-transfer_dm/{pathdate}/models/client-{client_idx}/{extra}.ckpt"
        print(f"model_path: {ckpt_path}")
        snapshots.append(deepcopy(load_model3(model, ckpt_path, cpu)))
    return snapshots


def load_model_cc_exp(model, pathdate, round=99):
    """Load FedAvg checkpoint(s) of a CC transfer experiment from a fixed path.

    Checkpoints are read (via ``load_model3``, onto the CPU) from
    ``/data/fat/alballns/outputs/transfer_exp/{pathdate}/FedAvg-model-round{r}.ckpt``.

    Args:
        model: Template module that ``load_model3`` fills with the weights.
        pathdate: Run-specific directory component of the path.
        round: A single round number (``int``) or an iterable of round numbers.
            The parameter name shadows the builtin but is kept so keyword
            callers keep working.

    Returns:
        For an ``int`` ``round``: the loaded model. Otherwise a list holding one
        independent (deep-copied) model per requested round, in order; an empty
        iterable yields an empty list (previously this raised
        ``UnboundLocalError``).
    """
    cpu = torch.device('cpu')

    # A bool is an int subclass but is not a meaningful round number.
    if isinstance(round, int) and not isinstance(round, bool):
        print(
            f"CC experiment, loading the model that was trained with FedAvg until round {round}, to continue training")
        model_path = f"/data/fat/alballns/outputs/transfer_exp/{pathdate}/FedAvg-model-round{round}.ckpt"
        print(f"model_path: {model_path}")
        return load_model3(model, model_path, cpu)

    # Several rounds: ``load_model3`` refills the same ``model`` object, so each
    # result is deep-copied before the next round is loaded.
    all_models = []
    for r in round:
        print(f"CC experiment, loading the models.. ")
        model_path = f"/data/fat/alballns/outputs/transfer_exp/{pathdate}/FedAvg-model-round{r}.ckpt"
        print(f"model_path: {model_path}")
        all_models.append(deepcopy(load_model3(model, model_path, cpu)))
    return all_models


## old version
# def load_models_dir_exp(model,pathdate,from_round=99,extra=""):
#     if not extra:
#         extra = ""
#     # doing this with static model path links, for a certain experiment
#     all_models = []
#     for i in range(10):
#         # moh-sands/KDN_N_exp/3gtkfi0e   #Dir0.1 5 local steps
#         print(f"Dirichlet alpha experiment, loading the models that was trained with FedAvg for {from_round+1} rounds, to continue training")
#         model_path = f"/data/fat/alballns/outputs/transfer_exp/opt-adam_from-C[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]-to-C[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]_dm-transfer_dm/{pathdate}/models/client-{i}/model-round{from_round}{extra}.ckpt"
#         i_model = load_model3(model, model_path, torch.device('cpu'))
#         all_models.append(deepcopy(i_model))
#
#     return all_models


def load_learner_model(client, clients, model):
    """Load the learner client's checkpoint into ``model`` on the CPU.

    ``clients`` is indexed with ``f"client_{client}"`` and the entry's
    ``model_path`` attribute is passed to ``load_model``.

    Returns:
        Whatever ``load_model`` returns for the learner checkpoint.
    """
    # TODO: the updating of state dict keys works for LeNet only and not ResNet; this NEEDS a FIX!!
    checkpoint_path = clients[f"client_{client}"].model_path
    return load_model(model, checkpoint_path, torch.device('cpu'))


def load_teacher_models(teachers, clients, model):
    """Load one or several teacher checkpoints onto the CPU.

    Args:
        teachers: A single client id, or a ``tuple``/``list``/``omegaconf.ListConfig``
            of client ids. The type test is an exact match, so subclasses of
            those types are treated as a single id.
        clients: Indexed with ``f"client_{id}"``; each entry exposes ``model_path``.
        model: Template module filled by ``load_model``.

    Returns:
        For a sequence of ids: a list of independent (deep-copied) models, one
        per id, in order. Otherwise: the single loaded model.
    """
    log.info(type(teachers))
    cpu = torch.device('cpu')

    kind = type(teachers)
    is_sequence = kind == tuple or kind == list or kind == omegaconf.ListConfig

    # Normal case: exactly one teacher.
    if not is_sequence:
        single_path = clients[f"client_{teachers}"].model_path
        return load_model(model, single_path, cpu)

    loaded = []
    for teacher_id in teachers:
        log.info('Loading multiple teachers')
        teacher_path = clients[f"client_{teacher_id}"].model_path
        # The template is refilled on every call, so keep a private copy.
        loaded.append(deepcopy(load_model(model, teacher_path, cpu)))
    return loaded


def load_models(clients_ids, clients, model, special_load=False):
    """Load the checkpoint(s) of one or several clients onto the CPU.

    Args:
        clients_ids: A single client id, or a ``tuple``/``list``/``omegaconf.ListConfig``
            of ids (exact type match).
        clients: Indexed with ``f"client_{id}"``; each entry exposes ``model_path``.
        model: Template module that the loader refills.
        special_load: For a sequence of ids, use ``load_model3`` instead of
            ``load_model``.

    Returns:
        For a sequence: a list of independent (deep-copied) models, one per id.
        Otherwise: the single loaded model.
    """
    log.info(type(clients_ids))
    cpu = torch.device('cpu')

    kind = type(clients_ids)
    if not (kind == tuple or kind == list or kind == omegaconf.ListConfig):
        # NOTE(review): ``special_load`` is not honoured on this single-id path;
        # ``load_model`` is always used here.
        single_path = clients[f"client_{clients_ids}"].model_path
        return load_model(model, single_path, cpu)

    collected = []
    for client_id in clients_ids:
        log.info('Loading multiple teachers')
        ckpt_path = clients[f"client_{client_id}"].model_path
        loader = load_model3 if special_load else load_model
        collected.append(deepcopy(loader(model, ckpt_path, cpu)))
    return collected


def load_saved_model(model, model_path):
    """Log the path, then load the checkpoint at ``model_path`` into ``model`` on the CPU.

    Returns:
        Whatever ``load_model`` returns.
    """
    log.info(f'Loading the model from {model_path}')
    return load_model(model, model_path, torch.device('cpu'))


def load_models_defKt(clients_ids, clients, model, special_load=False):
    """Load every client in ``clients_ids`` with ``load_model4`` onto the CPU.

    ``clients_ids`` is always iterated; a single non-iterable id is not
    special-cased. ``special_load`` is accepted for signature compatibility
    but is not used.

    Returns:
        list: Independent (deep-copied) models, one per id, in iteration order.
    """
    log.info(type(clients_ids))
    cpu = torch.device('cpu')

    snapshots = []
    for cid in clients_ids:
        log.info('Loading multiple teachers')
        ckpt_path = clients[f"client_{cid}"].model_path
        # ``load_model4`` refills the template, so store a copy per client.
        snapshots.append(deepcopy(load_model4(model, ckpt_path, cpu)))
    return snapshots


def load_client_models(num_clients, learner_client, teacher_client, clients, learner_model, teacher_model, device):
    """Load the learner checkpoint and one or more teacher checkpoints.

    Learner: loaded only when ``num_clients > 1`` and ``learner_client != -1``.
    With ``learner_client == -1`` the untrained ``learner_model`` is kept (and
    this is logged). In every other case ``learner_model`` is returned as is.

    Teacher(s): ``teacher_client`` is either a single id or a
    ``tuple``/``list``/``omegaconf.ListConfig`` of ids (exact type match). For a
    sequence, ``teacher_model`` is refilled per id and deep-copied, giving a list
    of independent models.

    Returns:
        tuple: ``(learner_model, teacher_model_or_list_of_teachers)``.
    """
    # TODO: the updating of state dict keys works for LeNet only and not ResNet; this NEEDS a FIX!!
    should_load_learner = num_clients > 1 and learner_client != -1
    if should_load_learner:
        learner_path = clients[f"client_{learner_client}"].model_path
        learner_model = load_model(learner_model, learner_path, device)
    elif learner_client == -1:
        log.info("Transfering to untrained learner!")

    log.info(type(teacher_client))
    kind = type(teacher_client)
    if not (kind == tuple or kind == list or kind == omegaconf.ListConfig):
        # Normal case: exactly one teacher.
        single_path = clients[f"client_{teacher_client}"].model_path
        return learner_model, load_model(teacher_model, single_path, device)

    teachers = []
    for teacher_id in teacher_client:
        log.info('Loading multiple teachers')
        teacher_path = clients[f"client_{teacher_id}"].model_path
        # ``load_model`` refills and returns the same module, so snapshot it.
        teacher_model = load_model(teacher_model, teacher_path, device)
        teachers.append(deepcopy(teacher_model))
    # TODO: re-enable the sanity check that no two loaded teachers have
    # identical weights (pairwise torch.equal over all state-dict entries).
    return learner_model, teachers


def load_learner_client(num_clients, learner_client, clients, learner_model, device):
    """Return ``learner_model``, loading the learner's checkpoint when applicable.

    The checkpoint is loaded only when ``num_clients > 1`` and
    ``learner_client != -1``. ``learner_client == -1`` keeps the untrained model
    and logs that fact; otherwise the model is returned unchanged.
    """
    # TODO: the updating of state dict keys works for LeNet only and not ResNet; this NEEDS a FIX!!
    wants_checkpoint = num_clients > 1 and learner_client != -1
    if wants_checkpoint:
        ckpt_path = clients[f"client_{learner_client}"].model_path
        return load_model(learner_model, ckpt_path, device)
    if learner_client == -1:
        log.info("Transfering to untrained learner!")
    return learner_model


def load_model(model, model_path, device):
    """Copy the weights of a checkpoint's ``'state_dict'`` into ``model`` by position.

    The i-th entry of ``model.state_dict()`` receives the i-th tensor of the
    checkpoint's ``'state_dict'``; key names in the checkpoint need not match the
    module's own names. If the two dicts have different lengths, the work is
    delegated to ``load_model3``.

    Args:
        model: Module whose parameters/buffers are overwritten in place.
        model_path: Path of a checkpoint holding a ``'state_dict'`` entry.
        device: Device on which the post-load equality check runs. The
            checkpoint itself is always read onto the CPU.

    Returns:
        ``model`` (the same object) with the loaded weights.

    Raises:
        AssertionError: If a tensor differs from the checkpoint after loading.
    """
    print(f"model_path: {model_path}")  # TODO: remove it
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    loaded_sd = checkpoint['state_dict']
    target_keys = list(model.state_dict().keys())

    if len(target_keys) != len(loaded_sd):
        return load_model3(model, model_path, device)

    # Positional re-keying: rename the checkpoint's tensors to the module's keys.
    updated_state_dict = OrderedDict(zip(target_keys, loaded_sd.values()))
    model.load_state_dict(updated_state_dict)

    # Check that the weights were loaded correctly. ``state_dict()`` builds a
    # new dict on every call, so take it once rather than once per entry.
    current_sd = model.state_dict()
    for key, expected in updated_state_dict.items():
        assert torch.equal(current_sd[key].to(device), expected.to(device))
    return model


def load_model2(model, model_path):
    """Copy a checkpoint's ``'state_dict'`` into ``model`` by position.

    The i-th entry of ``model.state_dict()`` receives the i-th tensor of the
    checkpoint's ``'state_dict'``; both must have the same number of entries.
    No post-load verification is performed.

    Args:
        model: Module whose parameters/buffers are overwritten in place.
        model_path: Path of a checkpoint holding a ``'state_dict'`` entry.

    Returns:
        ``model`` (the same object) with the loaded weights.

    Raises:
        AssertionError: If the entry counts differ.
    """
    # Read onto the CPU like the other loaders in this module:
    # ``load_state_dict`` copies into the module's own tensors anyway, and
    # GPU-saved checkpoints then also load on CPU-only hosts.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    loaded_sd = checkpoint['state_dict']
    target_keys = list(model.state_dict().keys())
    assert len(target_keys) == len(loaded_sd)

    model.load_state_dict(OrderedDict(zip(target_keys, loaded_sd.values())))
    return model


def _load_model3_fixed_offset(model, model_path, device, offset=122):
    """Legacy positional loader that skips the first ``offset`` checkpoint entries.

    This used to be a first definition named ``load_model3``. A later
    ``def load_model3`` in this module rebinds that name at import time, so this
    version was unreachable. It is renamed (private) so the duplicate definition
    no longer exists; callers of ``load_model3`` are unaffected.

    Args:
        model: Module whose parameters/buffers are overwritten in place.
        model_path: Path of a checkpoint holding a ``'state_dict'`` entry.
        device: Device on which the post-load equality check runs; the
            checkpoint is read onto the CPU.
        offset: Number of leading ``'state_dict'`` entries to skip (the
            original hard-coded 122).

    Returns:
        ``model`` (the same object) with the loaded weights.

    Raises:
        AssertionError: If the remaining entry count differs from the model's,
            or a tensor differs after loading.
    """
    loaded_pl_model = torch.load(model_path, map_location=torch.device('cpu'))
    print(f"type(loaded_pl_model): {type(loaded_pl_model)},type(model): {type(model)}")
    target_sd = model.state_dict()
    log.info(
        f"len(model.state_dict().items()) , len(loaded_pl_model['state_dict'].items()) {len(target_sd), len(loaded_pl_model['state_dict'])}")

    loaded_model = dict(list(loaded_pl_model['state_dict'].items())[offset:])
    assert len(target_sd) == len(loaded_model)

    # Positional re-keying onto the module's own key names.
    updated_state_dict = OrderedDict(zip(target_sd.keys(), loaded_model.values()))
    model.load_state_dict(updated_state_dict)

    # Check that the weights were loaded correctly (state_dict() taken once).
    current_sd = model.state_dict()
    for key, expected in updated_state_dict.items():
        assert torch.equal(current_sd[key].to(device), expected.to(device))
    return model


def load_model3(model, model_path, device, offset=None):
    """Load a checkpoint whose ``'state_dict'`` carries extra leading entries.

    The first ``offset`` entries of the checkpoint's ``'state_dict'`` are
    skipped (default: half of them). The remaining tensors are copied into
    ``model`` by position.

    Args:
        model: Module whose parameters/buffers are overwritten in place.
        model_path: Path of a checkpoint holding a ``'state_dict'`` entry.
        device: Device on which the post-load equality check runs; the
            checkpoint is read onto the CPU.
        offset: Number of leading entries to skip. ``None`` (default) keeps the
            original behaviour of skipping ``len(state_dict) // 2`` entries.

    Returns:
        ``model`` (the same object) with the loaded weights.

    Raises:
        AssertionError: If the remaining entry count differs from the model's,
            or a tensor differs after loading.
    """
    loaded_pl_model = torch.load(model_path, map_location=torch.device('cpu'))
    print(f"type(loaded_pl_model): {type(loaded_pl_model)},type(model): {type(model)}")
    loaded_items = list(loaded_pl_model['state_dict'].items())
    target_sd = model.state_dict()
    log.info(
        f"len(model.state_dict().items()) , len(loaded_pl_model['state_dict'].items()) {len(target_sd), len(loaded_items)}")

    if offset is None:
        offset = len(loaded_items) // 2
        print(f"half_loaded_model_len: {offset}")
    loaded_model = dict(loaded_items[offset:])

    assert len(target_sd) == len(loaded_model)

    # Positional re-keying onto the module's own key names.
    updated_state_dict = OrderedDict(zip(target_sd.keys(), loaded_model.values()))
    model.load_state_dict(updated_state_dict)

    # Check that the weights were loaded correctly. ``state_dict()`` builds a
    # new dict on every call, so take it once rather than once per entry.
    current_sd = model.state_dict()
    for key, expected in updated_state_dict.items():
        assert torch.equal(current_sd[key].to(device), expected.to(device))
    return model


def load_model_qktD(model, model_path, device):
    """Load a bare state dict, or a checkpoint wrapping one under ``'state_dict'``.

    The file is read with ``map_location=device``. If the loaded object is a
    dict containing a ``'state_dict'`` key, that inner dict is used. Its tensors
    are copied into ``model`` by position.

    Args:
        model: Module whose parameters/buffers are overwritten in place.
        model_path: Path of the saved state dict / checkpoint.
        device: Map location for loading and device for the equality check.

    Returns:
        ``model`` (the same object) with the loaded weights.

    Raises:
        AssertionError: If the entry counts differ or a tensor differs after
            loading.
    """
    loaded_state_dict = torch.load(model_path, map_location=device)
    print(f"type(loaded_state_dict): {type(loaded_state_dict)}, type(model): {type(model)}")

    if isinstance(loaded_state_dict, dict) and 'state_dict' in loaded_state_dict:
        loaded_state_dict = loaded_state_dict['state_dict']

    target_sd = model.state_dict()
    log.info(
        f"len(model.state_dict().items()), len(loaded_state_dict.items()): {len(target_sd), len(loaded_state_dict)}")

    assert len(target_sd) == len(loaded_state_dict)

    # Positional re-keying onto the module's own key names.
    updated_state_dict = OrderedDict(zip(target_sd.keys(), loaded_state_dict.values()))
    model.load_state_dict(updated_state_dict)

    # Check that the weights were loaded correctly (state_dict() taken once).
    current_sd = model.state_dict()
    for key, expected in updated_state_dict.items():
        assert torch.equal(current_sd[key].to(device), expected.to(device))
    return model


# def load_model_qktD(model, model_path, device):
#     loaded_state_dict = torch.load(model_path, map_location=device)
#     print(f"type(loaded_state_dict): {type(loaded_state_dict)}, type(model): {type(model)}")
#
#     if isinstance(loaded_state_dict, dict) and 'state_dict' in loaded_state_dict:
#         loaded_state_dict = loaded_state_dict['state_dict']
#
#     log.info(f"len(model.state_dict().items()), len(loaded_state_dict.items()): {len(model.state_dict().items()), len(loaded_state_dict.items())}")
#
#     updated_state_dict = OrderedDict()
#     for (k, v) in model.state_dict().items():
#         if k in loaded_state_dict:
#             updated_state_dict[k] = loaded_state_dict[k]
#         else:
#             log.warning(f"Key {k} not found in loaded state dict. Using initial weights.")
#             updated_state_dict[k] = v
#
#     model.load_state_dict(updated_state_dict)
#
#     # Check that the weights were loaded correctly
#     for (k1, v1), (k2, v2) in zip(model.state_dict().items(), updated_state_dict.items()):
#         assert torch.equal(model.state_dict()[k1].to(device), updated_state_dict[k2].to(device))
#
#     return model


def load_model4(model, model_path, device, offset=10):
    """Load a checkpoint after skipping its first ``offset`` ``'state_dict'`` entries.

    The remaining tensors are copied into ``model`` by position.

    Args:
        model: Module whose parameters/buffers are overwritten in place.
        model_path: Path of a checkpoint holding a ``'state_dict'`` entry.
        device: Device on which the post-load equality check runs; the
            checkpoint is read onto the CPU.
        offset: Number of leading entries to skip (default 10, the original
            hard-coded value).

    Returns:
        ``model`` (the same object) with the loaded weights.

    Raises:
        AssertionError: If the remaining entry count differs from the model's,
            or a tensor differs after loading.
    """
    loaded_pl_model = torch.load(model_path, map_location=torch.device('cpu'))
    print(f"type(loaded_pl_model): {type(loaded_pl_model)},type(model): {type(model)}")
    loaded_items = list(loaded_pl_model['state_dict'].items())
    target_sd = model.state_dict()
    log.info(
        f"len(model.state_dict().items()) , len(loaded_pl_model['state_dict'].items()) {len(target_sd), len(loaded_items)}")

    loaded_model = dict(loaded_items[offset:])
    assert len(target_sd) == len(loaded_model)

    # Positional re-keying onto the module's own key names.
    updated_state_dict = OrderedDict(zip(target_sd.keys(), loaded_model.values()))
    model.load_state_dict(updated_state_dict)

    # Check that the weights were loaded correctly. ``state_dict()`` builds a
    # new dict on every call, so take it once rather than once per entry.
    current_sd = model.state_dict()
    for key, expected in updated_state_dict.items():
        assert torch.equal(current_sd[key].to(device), expected.to(device))
    return model
